{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from lightning import pytorch as pl\n",
    "import numpy as np\n",
    "from numpy.typing import ArrayLike\n",
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "import torch\n",
    "from torch import Tensor\n",
    "import torchmetrics\n",
    "import logging\n",
    "\n",
    "from chemprop import data, models, nn\n",
    "from chemprop.nn.metrics import ChempropMetric, MetricRegistry\n",
    "\n",
    "logger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Available metric functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Chemprop provides several metrics. The functions calculate a single value that serves as a measure of model performance. Users only need to select the metric(s) to use. The rest of the details are handled by Chemprop and the lightning trainer, which logs all metric values to the trainer logger (defaults to TensorBoard) for the validation and test sets. Note that the validation metrics are in the scaled space while the test metrics are in the original target space.\n",
    "\n",
    "See also [loss functions](./loss_functions.ipynb) which are the same as metrics, except used to optimize the model and therefore required to be differentiable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mse\n",
      "mae\n",
      "rmse\n",
      "bounded-mse\n",
      "bounded-mae\n",
      "bounded-rmse\n",
      "r2\n",
      "binary-mcc\n",
      "multiclass-mcc\n",
      "roc\n",
      "prc\n",
      "accuracy\n",
      "f1\n"
     ]
    }
   ],
   "source": [
    "for metric in MetricRegistry:\n",
    "    print(metric)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Specifying metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each FFN predictor has a default metric. If you want different metrics reported, you can give a list of metrics to the model at creation. Note that the list of metrics is used in place of the default metric and not in addition to the default metric."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.nn.metrics import MSE, MAE, RMSE\n",
    "\n",
    "metrics = [MSE(), MAE(), RMSE()]\n",
    "model = models.MPNN(nn.BondMessagePassing(), nn.MeanAggregation(), nn.RegressionFFN(), metrics=metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Accumulating metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Chemprop metrics are based on `Metric` from `torchmetrics` which stores the information from each batch that is needed to calculate the metric over the whole validation or test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (mps), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n",
      "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing DataLoader 0:   0%|                                                                                                                      | 0/5 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/brianli/Documents/chemprop/chemprop/nn/message_passing/base.py:263: UserWarning: The operator 'aten::scatter_reduce.two_out' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:13.)\n",
      "  M_all = torch.zeros(len(bmg.V), H.shape[1], dtype=H.dtype, device=H.device).scatter_reduce_(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  7.27it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test/mae          </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.4941912293434143     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test/mse          </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.3071698546409607     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test/rmse         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.5542290806770325     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test/mae         \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.4941912293434143    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test/mse         \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.3071698546409607    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test/rmse        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.5542290806770325    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicting DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 147.05it/s]\n"
     ]
    }
   ],
   "source": [
    "smis = [\"C\" * i for i in range(1, 11)]\n",
    "ys = np.random.rand(len(smis), 1)\n",
    "dset = data.MoleculeDataset([data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)])\n",
    "dataloader = data.build_dataloader(dset, shuffle=False, batch_size=2)\n",
    "\n",
    "trainer = pl.Trainer(logger=False, enable_checkpointing=False, max_epochs=1)\n",
    "result_when_batched = trainer.test(model, dataloader)\n",
    "preds = trainer.predict(model, dataloader)\n",
    "preds = torch.concat(preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batch / Not Batched\n",
      "0.5542 / 0.5542\n"
     ]
    }
   ],
   "source": [
    "result_when_not_batched = RMSE()(preds, torch.from_numpy(dset.Y), None, None, None, None)\n",
    "print(\"Batch / Not Batched\")\n",
    "print(f\"{result_when_batched[0]['test/rmse']:.4f} / {result_when_not_batched.item():.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Batch normalization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It is worth noting that if your model has a batch normalization layer, the computed metric will be different depending on if the model is in training or evaluation mode. When a batch normalization layer is training, it uses a biased estimator to calculate the standard deviation, but the value stored and used during evaluation is calculated with the unbiased estimator. Lightning takes care of this if the `Trainer()` is used. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Regression"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are several metric options for regression. `MSE` is the default. There are also bounded versions (except for r2), similar to the bounded versions of the [loss functions](./loss_functions.ipynb). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.nn.metrics import MSE, MAE, RMSE, R2Score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.nn.metrics import BoundedMAE, BoundedMSE, BoundedRMSE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Classification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are metrics for both binary and multiclass classification."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.nn.metrics import (\n",
    "    BinaryAUROC,\n",
    "    BinaryAUPRC,\n",
    "    BinaryAccuracy,\n",
    "    BinaryF1Score,\n",
    "    BinaryMCCMetric,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.nn.metrics import MulticlassMCCMetric"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Spectra"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Spectral information divergence and wasserstein (earthmovers distance) are often used for spectral predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.nn.metrics import SID, Wasserstein"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Custom metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Chemprop metrics are instances of `chemprop.nn.metrics.ChempropMetric`, which inherits from `torchmetrics.Metric`. Custom loss functions need to follow the interface of both `ChempropMetric` and `Metric`. Start with a `Metric` either by importing an existing one from `torchmetrics` or by creating your own by following the instructions on the `torchmetrics` website. Then make the following changes:\n",
    "\n",
    "1. Allow for task weights to be passed to the `__init__` method.\n",
    "2. Allow for the `update` method to be given `preds, targets, mask, weights, lt_mask, gt_mask` in that order.\n",
    "3. Provide an alias property, which is used to identify the metric value in the logs.\n",
    "\n",
    "* `preds`: A `Tensor` of the model's predictions with dimension 0 being the batch dimension and dimension 1 being the task dimension.\n",
    "* `targets`: A `Tensor` of the target values with dimension 0 being the batch dimension and dimension 1 being the task dimension.\n",
    "* `mask`: A `Tensor` of the same shape as `targets` with `True`s where the target value is present and finite and `False` where it is not.\n",
    "* `weights`: Usually ignored in metrics.\n",
    "* `lt_mask`: A `Tensor` of the same shape as `targets` with `True`s where the target value is a \"less than\" target value and `False` where it is not.\n",
    "* `gt_mask`: A `Tensor` of the same shape as `targets` with `True`s where the target value is a \"greater than\" target value and `False` where it is not."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ChempropMulticlassAUROC(torchmetrics.classification.MulticlassAUROC):\n",
    "    def __init__(self, task_weights: ArrayLike = 1.0, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.task_weights = torch.as_tensor(task_weights, dtype=torch.float).view(1, -1)\n",
    "        if (self.task_weights != 1.0).any():\n",
    "            logger.warn(\"task_weights were provided but are ignored by metric \"\n",
    "                          f\"{self.__class__.__name__}. Got {task_weights}\")\n",
    "\n",
    "    def update(self, preds: Tensor, targets: Tensor, mask: Tensor | None = None, *args, **kwargs):\n",
    "        if mask is None:\n",
    "            mask = torch.ones_like(targets, dtype=torch.bool)\n",
    "\n",
    "        super().update(preds[mask], targets[mask].long())\n",
    "\n",
    "    @property\n",
    "    def alias(self) -> str:\n",
    "        return \"multiclass_auroc\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Alternatively, if your metric can return a value for every task for every data point (i.e. not reduced in the task or batch dimension), you can inherit from `chemprop.nn.metrics.ChempropMetric` and just override the `_calc_unreduced_loss` method (and if needed the `__init__` method)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BoundedNormalizedMSEPlus1(ChempropMetric):\n",
    "    def __init__(self, task_weights = None, norm: float = 1.0):\n",
    "        super().__init__(task_weights)\n",
    "        norm = torch.as_tensor(norm)\n",
    "        self.register_buffer(\"norm\", norm)\n",
    "\n",
    "    def _calc_unreduced_loss(self, preds, targets, mask, weights, lt_mask, gt_mask) -> Tensor:\n",
    "        preds = torch.where((preds < targets) & lt_mask, targets, preds)\n",
    "        preds = torch.where((preds > targets) & gt_mask, targets, preds)\n",
    "\n",
    "        return torch.sum((preds - targets) ** 2) / self.norm + 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set up the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "chemprop_dir = Path.cwd().parents[3]\n",
    "input_path = chemprop_dir / \"tests\" / \"data\" / \"classification\" / \"mol_multiclass.csv\"\n",
    "df_input = pd.read_csv(input_path)\n",
    "smis = df_input.loc[:, \"smiles\"].values\n",
    "ys = df_input.loc[:, [\"activity\"]].values\n",
    "all_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]\n",
    "train_indices, val_indices, test_indices = data.make_split_indices(all_data, \"random\", (0.8, 0.1, 0.1))\n",
    "train_data, val_data, test_data = data.split_data_by_indices(\n",
    "    all_data, train_indices, val_indices, test_indices\n",
    ")\n",
    "train_dset = data.MoleculeDataset(train_data[0])\n",
    "val_dset = data.MoleculeDataset(val_data[0])\n",
    "test_dset = data.MoleculeDataset(test_data[0])\n",
    "train_loader = data.build_dataloader(train_dset)\n",
    "val_loader = data.build_dataloader(val_dset, shuffle=False)\n",
    "test_loader = data.build_dataloader(test_dset, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Make a model with a custom loss function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_classes = max(ys).item() + 1\n",
    "\n",
    "metrics = [ChempropMulticlassAUROC(num_classes = n_classes)]\n",
    "\n",
    "model = models.MPNN(\n",
    "    nn.BondMessagePassing(), \n",
    "    nn.NormAggregation(), \n",
    "    nn.MulticlassClassificationFFN(n_classes = n_classes), \n",
    "    metrics = metrics\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (mps), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n",
      "Loading `train_dataloader` to estimate number of stepping batches.\n",
      "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.\n",
      "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (7) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
      "\n",
      "  | Name            | Type                        | Params | Mode \n",
      "------------------------------------------------------------------------\n",
      "0 | message_passing | BondMessagePassing          | 227 K  | train\n",
      "1 | agg             | NormAggregation             | 0      | train\n",
      "2 | bn              | Identity                    | 0      | train\n",
      "3 | predictor       | MulticlassClassificationFFN | 91.2 K | train\n",
      "4 | X_d_transform   | Identity                    | 0      | train\n",
      "5 | metrics         | ModuleList                  | 0      | train\n",
      "------------------------------------------------------------------------\n",
      "318 K     Trainable params\n",
      "0         Non-trainable params\n",
      "318 K     Total params\n",
      "1.276     Total estimated model params size (MB)\n",
      "24        Modules in train mode\n",
      "0         Modules in eval mode\n",
      "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/core/saving.py:363: Skipping 'metrics' parameter because it is not possible to safely dump to YAML.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sanity Checking DataLoader 0:   0%|                                                                                                              | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.\n",
      "/var/folders/hx/9k64y_d101zg5t41l44wfs_w0000gn/T/ipykernel_3641/67158073.py:13: UserWarning: MPS: nonzero op is supported natively starting from macOS 13.0. Falling back on CPU. This may have performance implications. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/mps/operations/Indexing.mm:335.)\n",
      "  super().update(preds[mask], targets[mask].long())\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sanity Checking DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.29it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: No positive samples in targets, true positive value should be meaningless. Returning zero tensor in true positive score\n",
      "  warnings.warn(*args, **kwargs)  # noqa: B028\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00,  5.52it/s, v_num=2, train_loss_step=1.030]\n",
      "Validation: |                                                                                                                                    | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|                                                                                                                                | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|                                                                                                                   | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.06it/s]\u001b[A\n",
      "Epoch 1: 100%|███████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.83it/s, v_num=2, train_loss_step=0.278, val_loss=0.985, train_loss_epoch=1.070]\u001b[A\n",
      "Validation: |                                                                                                                                    | 0/? [00:00<?, ?it/s]\u001b[A\n",
      "Validation:   0%|                                                                                                                                | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0:   0%|                                                                                                                   | 0/1 [00:00<?, ?it/s]\u001b[A\n",
      "Validation DataLoader 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 22.15it/s]\u001b[A\n",
      "Epoch 1: 100%|███████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.72it/s, v_num=2, train_loss_step=0.278, val_loss=0.784, train_loss_epoch=0.756]\u001b[A"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=2` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1: 100%|███████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.25it/s, v_num=2, train_loss_step=0.278, val_loss=0.784, train_loss_epoch=0.756]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  6.80it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">   test/multiclass_auroc   </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.6266666650772095     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m  test/multiclass_auroc  \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.6266666650772095    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "[{'test/multiclass_auroc': 0.6266666650772095}]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer = pl.Trainer(max_epochs=2)\n",
    "trainer.fit(model, train_loader, val_loader)\n",
    "trainer.test(model, test_loader)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
